/* Copyright (c) 2023, Canaan Bright Sight Co., Ltd
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 * 1. Redistributions of source code must retain the above copyright
 * notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 * notice, this list of conditions and the following disclaimer in the
 * documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
 * CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
 * INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
 * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
 * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
 * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
#include "face_emotion.h"
#include <vector>

// Image color set.
// Seven three-component colors, the same count as the seven emotion labels
// assigned to label_list_ in the constructor. None of the code visible in this
// file reads this table; draw_result() uses color_list_for_osd_emo instead.
// NOTE(review): channel order is presumably BGR (OpenCV's usual convention) — confirm.
cv::Scalar color_list_for_emo[] = {
    cv::Scalar(255, 255, 0),
    cv::Scalar(0, 255, 0),
    cv::Scalar(50, 220, 255),
    cv::Scalar(255, 0, 255),
    cv::Scalar(0, 0, 255),
    cv::Scalar(0, 170, 255),
    cv::Scalar(0, 255, 255)};

// Four-component colors used by draw_result() for the label text when
// pic_mode is false. The table is indexed by FaceEmotionInfo::idx, so it holds
// one entry per label in label_list_ (seven in total).
// Each entry is the matching color_list_for_emo entry with its first three
// components in reverse order, followed by a fourth component of 255.
// NOTE(review): the component layout presumably matches the OSD frame's pixel
// format — confirm against the display pipeline.
cv::Scalar color_list_for_osd_emo[] = {
    cv::Scalar(0, 255, 255, 255),
    cv::Scalar(0, 255, 0, 255),
    cv::Scalar(255, 220, 50, 255),
    cv::Scalar(255, 0, 255, 255),
    cv::Scalar(255, 0, 0, 255),
    cv::Scalar(255, 170, 0, 255),
    cv::Scalar(255, 255, 0, 255)
};



/**
 * @brief Constructs the emotion classifier on top of AIBase.
 *
 * Forwards the kmodel path, the fixed name "FaceEmotion" and the debug level to
 * AIBase, then records the label table, the caller's frame size and the model
 * input geometry, and binds the ai2d output to the model's input tensor 0.
 *
 * @param kmodel_file Path of the kmodel, passed through to AIBase.
 * @param image_size  Size of the frames that will be handed to pre_process().
 * @param debug_mode  Debug level, passed through to AIBase.
 */
FaceEmotion::FaceEmotion(char *kmodel_file, FrameCHWSize image_size, int debug_mode)
    : AIBase(kmodel_file, "FaceEmotion", debug_mode)
{
    image_size_ = image_size;
    model_name_ = "FaceEmotion";

    // Order matters: post_process() maps the arg-max index straight into this list.
    label_list_ = {"Anger", "Disgust", "Fear", "Happiness", "Neutral", "Sadness", "Surprise"};

    // Elements 1..3 of the first input shape (element 0 is skipped).
    // NOTE(review): presumably an NCHW shape, i.e. channel/height/width — confirm.
    const auto &first_input_shape = input_shapes_[0];
    input_size_ = {first_input_shape[1], first_input_shape[2], first_input_shape[3]};

    // The ai2d stage writes directly into the model's input tensor 0.
    ai2d_out_tensor_ = get_input_tensor(0);
}

/// Destructor: performs no work of its own beyond destroying members and the base class.
FaceEmotion::~FaceEmotion() = default;

/**
 * @brief Prepares the model input for one face.
 *
 * @param input_tensor  Source tensor handed to the ai2d builder.
 * @param sparse_points Landmark coordinates forwarded to get_affine_matrix(),
 *                      which reads ten floats from it.
 */
void FaceEmotion::pre_process(runtime_tensor &input_tensor, float *sparse_points)
{
    ScopedTiming st(model_name_ + " pre_process", debug_mode_);

    // 1. Fill matrix_dst_ with the landmark-to-template transform.
    this->get_affine_matrix(sparse_points);

    // 2. Hand that matrix and both sizes to the ai2d builder setup.
    //    NOTE(review): presumably configures an affine warp — see Utils::affine_set.
    Utils::affine_set(image_size_, input_size_, ai2d_builder_, matrix_dst_);

    // 3. Run ai2d from the source tensor into ai2d_out_tensor_.
    ai2d_builder_->invoke(input_tensor, ai2d_out_tensor_)
        .expect("error occurred in ai2d running");
}


void FaceEmotion::inference()
{
    this->run();
    this->get_output();
}

/**
 * @brief Converts the raw output 0 into an emotion prediction.
 *
 * Copies output_shapes_[0][0] * output_shapes_[0][1] floats from p_outputs_[0],
 * applies softmax() and reports the highest-probability class.
 *
 * @param result Receives the winning class index, its label_list_ entry and
 *               its softmax probability.
 */
void FaceEmotion::post_process(FaceEmotionInfo &result)
{
    ScopedTiming st(model_name_ + " post_process", debug_mode_);

    // Snapshot the raw scores of output 0 into an owning vector.
    const auto raw_begin = p_outputs_[0];
    const auto raw_end = raw_begin + output_shapes_[0][0] * output_shapes_[0][1];
    std::vector<float> raw_scores(raw_begin, raw_end);

    std::vector<float> probabilities;
    softmax(raw_scores, probabilities);

    // Arg-max over the probabilities selects the reported emotion.
    const auto best = std::max_element(probabilities.begin(), probabilities.end());
    const int best_index = static_cast<int>(std::distance(probabilities.begin(), best));

    result.idx = best_index;
    result.label = label_list_[best_index];
    result.score = *best;
}

/**
 * @brief Draws the face box and the emotion label onto src_img.
 *
 * pic_mode == true: bbox is used as-is in src_img pixel coordinates. The box is
 * white, and the label is drawn at font scale 0.5 in cv::Scalar(255, 0, 0).
 *
 * pic_mode == false: bbox is rescaled by (src_img size / image_size_) before
 * drawing. The label is drawn at font scale 2 in
 * color_list_for_osd_emo[result.idx], so result.idx must be a valid index into
 * that table.
 *
 * In both modes the label is anchored 10 pixels above the box top, clamped to 0.
 *
 * @param src_img  Image drawn on in place.
 * @param bbox     Face box (x, y, w, h).
 * @param result   Prediction whose label is drawn and whose idx picks the OSD color.
 * @param pic_mode Selects between the two drawing modes described above.
 */
void FaceEmotion::draw_result(cv::Mat& src_img,Bbox& bbox,FaceEmotionInfo& result, bool pic_mode)
{
    // Use the label's own C string. The previous code sprintf'ed it into a
    // 30-byte stack buffer with no bound, which overflows for any label of
    // 30 characters or more.
    const char *text = result.label.c_str();

    // NOTE(review): the lineType argument of both rectangle() calls is 2, which
    // is not one of OpenCV's named line types (LINE_4 / LINE_8 / LINE_AA). It is
    // kept unchanged here — confirm the intent.
    if (pic_mode)
    {
        cv::rectangle(src_img, cv::Rect(bbox.x, bbox.y , bbox.w, bbox.h), cv::Scalar(255, 255, 255), 2, 2, 0);
        cv::putText(src_img, text , {bbox.x,std::max(int(bbox.y-10),0)}, cv::FONT_HERSHEY_COMPLEX, 0.5, cv::Scalar(255, 0, 0), 1, 8, 0);
    }
    else
    {
        const int src_w = src_img.cols;
        const int src_h = src_img.rows;

        // Map the box from image_size_ coordinates to src_img coordinates.
        int x = bbox.x / image_size_.width * src_w;
        int y = bbox.y / image_size_.height * src_h;
        int w = bbox.w / image_size_.width * src_w;
        int h = bbox.h / image_size_.height * src_h;
        cv::rectangle(src_img, cv::Rect(x, y , w, h), cv::Scalar(255,255, 255, 255), 2, 2, 0);
        cv::putText(src_img, text , {x,std::max(int(y-10),0)}, cv::FONT_HERSHEY_COMPLEX, 2, color_list_for_osd_emo[result.idx], 2, 8, 0);
    }
}

/**
 * @brief Closed-form singular value decomposition of a 2x2 matrix.
 *
 * All matrices are flat row-major arrays of four floats.
 * image_umeyama_224() multiplies the returned u by the returned v directly, so
 * v is used there as the right-hand factor.
 * NOTE(review): presumably a == u * diag(s) * v, with v playing the role of
 * V-transpose — confirm.
 *
 * @param a Input matrix.
 * @param u Receives the left factor.
 * @param s Receives the two singular values, s[0] being the larger one.
 * @param v Receives the right factor.
 */
void FaceEmotion::svd22(const float a[4], float u[4], float s[2], float v[4])
{
	const float a00 = a[0];
	const float a01 = a[1];
	const float a10 = a[2];
	const float a11 = a[3];

	// Both singular values are built from these two root terms.
	const float root_diff = sqrtf(powf(a00 - a11, 2) + powf(a01 + a10, 2));
	const float root_sum = sqrtf(powf(a00 + a11, 2) + powf(a01 - a10, 2));

	const float sigma_hi = (root_diff + root_sum) / 2;
	const float sigma_lo = fabsf(sigma_hi - root_diff);
	s[0] = sigma_hi;
	s[1] = sigma_lo;

	// Sine of half the rotation angle of the right factor. It stays 0 when the
	// two singular values are not strictly ordered.
	float sin_half = 0;
	if (sigma_hi > sigma_lo)
	{
		const float full_angle = atan2f(2 * (a00 * a01 + a10 * a11), a00 * a00 - a01 * a01 + a10 * a10 - a11 * a11);
		sin_half = sinf(full_angle / 2);
	}
	const float cos_half = sqrtf(1 - sin_half * sin_half);

	// Second row of the intermediate right factor [cos, -sin; sin, cos].
	const float neg_sin_half = -sin_half;

	// First column of u, with a fixed fallback when sigma_hi is zero.
	float u00;
	float u10;
	if (sigma_hi != 0)
	{
		u00 = -(a00 * cos_half + a01 * sin_half) / sigma_hi;
		u10 = -(a10 * cos_half + a11 * sin_half) / sigma_hi;
	}
	else
	{
		u00 = 1;
		u10 = 0;
	}

	// Second column of u. When sigma_lo is zero it is derived from the first column.
	float u01;
	float u11;
	if (sigma_lo != 0)
	{
		u01 = (a00 * neg_sin_half + a01 * cos_half) / sigma_lo;
		u11 = (a10 * neg_sin_half + a11 * cos_half) / sigma_lo;
	}
	else
	{
		u01 = -u10;
		u11 = u00;
	}

	u[0] = u00;
	u[1] = u01;
	u[2] = u10;
	u[3] = u11;

	// The first column of v is sign-flipped, matching the negation applied to
	// the first column of u above.
	v[0] = -cos_half;
	v[1] = neg_sin_half;
	v[2] = -sin_half;
	v[3] = cos_half;
}

 // Standard facial landmark points at 224 resolution.
 // Five interleaved (x, y) pairs; each literal is scaled by PIC_SIZE / 112.
 // image_umeyama_224() reads this table as the alignment target.
 // NOTE(review): the pairs are presumably ordered left eye, right eye, nose,
 // left mouth corner, right mouth corner — confirm against the landmark model.
#define PIC_SIZE 224
static float umeyama_args_224[5 * 2] =
{
	38.2946 * PIC_SIZE / 112, 51.6963 * PIC_SIZE / 112,
	73.5318 * PIC_SIZE / 112, 51.5014 * PIC_SIZE / 112,
	56.0252 * PIC_SIZE / 112, 71.7366 * PIC_SIZE / 112,
	41.5493 * PIC_SIZE / 112, 92.3655 * PIC_SIZE / 112,
	70.7299 * PIC_SIZE / 112, 92.2041 * PIC_SIZE / 112
};

/**
 * @brief Estimates the similarity transform (rotation, uniform scale,
 *        translation) that maps five source landmarks onto umeyama_args_224.
 *
 * Steps: centroids -> demeaned points -> 2x2 cross-covariance -> svd22() ->
 * rotation = U * V -> scale = (S0 + S1) / total source variance -> translation.
 *
 * No determinant-sign (reflection) correction is applied to the rotation.
 * If all source points coincide, the total variance is zero and the resulting
 * matrix entries are not finite.
 *
 * @param src Ten floats, read as five interleaved (x, y) pairs.
 * @param dst Receives NINE floats: a row-major 3x3 matrix whose top two rows
 *            are [scale * R | t] and whose last row is (0, 0, 1). The buffer
 *            must therefore hold at least nine floats.
 */
void FaceEmotion::image_umeyama_224(float* src, float* dst)
{
	constexpr int kNumPoints = 5; // landmarks per face
	constexpr int kDim = 2;       // coordinates per landmark

	// Centroids of the source landmarks and of the reference template.
	float src_mean[kDim] = { 0.0f, 0.0f };
	float dst_mean[kDim] = { 0.0f, 0.0f };
	for (int p = 0; p < kNumPoints; ++p)
	{
		for (int d = 0; d < kDim; ++d)
		{
			src_mean[d] += src[kDim * p + d];
			dst_mean[d] += umeyama_args_224[kDim * p + d];
		}
	}
	for (int d = 0; d < kDim; ++d)
	{
		src_mean[d] /= kNumPoints;
		dst_mean[d] /= kNumPoints;
	}

	// Points relative to their centroid.
	float src_demean[kNumPoints][kDim];
	float dst_demean[kNumPoints][kDim];
	for (int p = 0; p < kNumPoints; ++p)
	{
		for (int d = 0; d < kDim; ++d)
		{
			src_demean[p][d] = src[kDim * p + d] - src_mean[d];
			dst_demean[p][d] = umeyama_args_224[kDim * p + d] - dst_mean[d];
		}
	}

	// Cross-covariance, flat row-major: cov[r][c] = mean_p(dst_demean[p][r] * src_demean[p][c]).
	float cov[kDim * kDim];
	for (int r = 0; r < kDim; ++r)
	{
		for (int c = 0; c < kDim; ++c)
		{
			float acc = 0.0f;
			for (int p = 0; p < kNumPoints; ++p)
			{
				acc += dst_demean[p][r] * src_demean[p][c];
			}
			cov[r * kDim + c] = acc / kNumPoints;
		}
	}

	float u[kDim * kDim] = { 0 };
	float s[kDim] = { 0 };
	float v[kDim * kDim] = { 0 };
	svd22(cov, u, s, v);

	// Unscaled rotation: the 2x2 product u * v.
	const float r00 = u[0] * v[0] + u[1] * v[2];
	const float r01 = u[0] * v[1] + u[1] * v[3];
	const float r10 = u[2] * v[0] + u[3] * v[2];
	const float r11 = u[2] * v[1] + u[3] * v[3];

	// Per-axis population variance of the demeaned source points.
	float demean_avg[kDim] = { 0.0f, 0.0f };
	for (int p = 0; p < kNumPoints; ++p)
	{
		for (int d = 0; d < kDim; ++d)
		{
			demean_avg[d] += src_demean[p][d];
		}
	}
	for (int d = 0; d < kDim; ++d)
	{
		demean_avg[d] /= kNumPoints;
	}

	float variance[kDim] = { 0.0f, 0.0f };
	for (int p = 0; p < kNumPoints; ++p)
	{
		for (int d = 0; d < kDim; ++d)
		{
			const float diff = demean_avg[d] - src_demean[p][d];
			variance[d] += diff * diff;
		}
	}
	for (int d = 0; d < kDim; ++d)
	{
		variance[d] /= kNumPoints;
	}

	// Evaluated in double (as before) and then narrowed to float.
	const float scale = static_cast<float>(1.0 / (variance[0] + variance[1]) * (s[0] + s[1]));

	// Row 0: scaled rotation and x translation.
	dst[0] = r00 * scale;
	dst[1] = r01 * scale;
	dst[2] = dst_mean[0] - scale * (r00 * src_mean[0] + r01 * src_mean[1]);

	// Row 1: scaled rotation and y translation.
	dst[3] = r10 * scale;
	dst[4] = r11 * scale;
	dst[5] = dst_mean[1] - scale * (r10 * src_mean[0] + r11 * src_mean[1]);

	// Row 2: homogeneous row.
	dst[6] = 0;
	dst[7] = 0;
	dst[8] = 1;
}

/**
 * @brief Numerically stable softmax: exp(x - max(x)) / sum(exp(x - max(x))).
 *
 * @param input  Raw scores. An empty input yields an empty output.
 * @param output Resized to input.size() and filled with the probabilities.
 *               May be the same vector as input.
 */
void FaceEmotion::softmax(vector<float>& input,vector<float>& output)
{
    // std::max_element returns end() for an empty range, so dereferencing it
    // would be undefined behaviour. Handle that case explicitly.
    if (input.empty())
    {
        output.clear();
        return;
    }

    // Subtracting the maximum keeps every exponent <= 0, so exp() cannot overflow.
    const float input_max = *std::max_element(input.begin(), input.end());

    float input_total = 0;
    for (const float x : input)
    {
        input_total += exp(x - input_max);
    }

    output.resize(input.size());
    for (std::size_t i = 0; i < input.size(); ++i)
    {
        output[i] = exp(input[i] - input_max) / input_total;
    }
}

void FaceEmotion::get_affine_matrix(float* sparse_points)
{
    float matrix_src[5][2];
    for (uint32_t i = 0; i < 5; ++i)
    {
        matrix_src[i][0] = sparse_points[2 * i + 0];
		matrix_src[i][1] = sparse_points[2 * i + 1];
    }
    image_umeyama_224(&matrix_src[0][0], &matrix_dst_[0]);
}